# -*- coding: utf-8 -*-
'''
Created on 2017-04-21

@author: ZhuJiahui506
'''

from collections import Counter

from sklearn.cluster import KMeans
import numpy as np


def get_list_max_count(data_list):
    '''
    Return the occurrence count of the most frequent element in the list.
    :param data_list: iterable of hashable items
    :return: count of the most common element, or 0 for an empty list
    '''
    
    # Counter tallies frequencies in a single pass; ``default=0`` preserves
    # the original behaviour of returning 0 for an empty input.
    return max(Counter(data_list).values(), default=0)

def simple_means(cluster_num, cluster_data, class_tag, max_iter=100, random_state=None):
    '''
    Run K-Means on the data and report three cluster-quality metrics.
    :param cluster_num: number of clusters
    :param cluster_data: 2-D array-like, one sample per row
    :param class_tag: ground-truth class label per sample (same order as rows)
    :param max_iter: maximum K-Means iterations (default 100, as before)
    :param random_state: optional seed to make the clustering reproducible
    :return: (inertia, sum of inter-cluster center distances, purity)
    '''
    kmeans_result = KMeans(n_clusters=cluster_num, max_iter=max_iter,
                           random_state=random_state).fit(cluster_data)
    centers = kmeans_result.cluster_centers_
    labels = kmeans_result.labels_

    # Sum of Euclidean distances between every pair of cluster centers.
    cluster_distance = 0.0
    for j in range(cluster_num):
        for k in range(j + 1, cluster_num):
            cluster_distance += np.linalg.norm(centers[j] - centers[k])

    # Group sample indices by their assigned cluster label.
    each_cluster_dict = {j: [] for j in range(cluster_num)}
    for index, label in enumerate(labels):
        each_cluster_dict[label].append(index)

    # Purity: for each cluster count the majority ground-truth class,
    # then divide the total by the number of samples.
    correct_count = 0
    for members in each_cluster_dict.values():
        if members:
            real_tags = [class_tag[index] for index in members]
            correct_count += get_list_max_count(real_tags)
    purity = np.true_divide(correct_count, len(labels))

    return kmeans_result.inertia_, cluster_distance, purity


if __name__ == '__main__':
    
    cluster_num = 2
    cluster_data = np.array([[1, 1, 1, 0, 0], [0, 0, 0, 3, 3], [1.1, 1, 1, 0, 0], [0.1, 0.1, 0, 3, 3], [1, 1.1, 1, 0, 0]])
    class_tag = [1, 0, 1, 0, 1]
    
    # Fit once here only to show the raw model outputs; the metric
    # computation itself lives in simple_means() instead of being
    # duplicated inline (the original copy-pasted the whole pipeline).
    kmeans_result = KMeans(n_clusters=cluster_num, max_iter=100).fit(cluster_data)
    print(kmeans_result.cluster_centers_)  # cluster centers
    print(kmeans_result.labels_)  # cluster label per sample
    
    inertia, cluster_distance, purity = simple_means(cluster_num, cluster_data, class_tag)
    print(inertia)  # within-cluster sum of squared distances
    print(cluster_distance)  # sum of inter-cluster center distances
    print(purity)  # clustering purity against class_tag
